1. Introduction

Problem Description and Objectives

According to the World Health Organization, stroke is the second leading cause of death globally, and according to the Heart Disease and Stroke Statistics 2019 report, it is the fifth leading cause of death in the United States. The most recent Heart Disease and Stroke Statistics (HDSS) report also shows that someone in the United States has a stroke every 40 seconds and someone dies from a stroke every 3.5 minutes. Given the prevalence and seriousness of the condition, being able to predict a person’s likelihood of suffering a stroke in advance could help in assessing risk and planning treatment accordingly.

Using the “healthcare-dataset-stroke-data” dataset from Kaggle, we want to find out, first, which variables are associated with a patient having a stroke and, second, whether we can build a model that predicts whether or not a patient will have a stroke. While previous research shows that age, heart disease, average glucose level and hypertension are the most important factors for stroke prediction, the dataset we are using contains all of these as well as several other variables that may reveal interesting patterns.

Dataset and Data Mining Task

The data for this project comes from Kaggle and contains 5,110 observations with twelve attributes: id, gender, age, whether hypertension is present, whether heart disease is present, whether the patient has ever been married, type of work, type of residence (rural or urban), average glucose level, BMI, smoking status and whether or not the patient had a stroke. Each row of data corresponds to one patient:

Attribute Information

  1. id: unique identifier

  2. gender: “Male”, “Female” or “Other”

  3. age: age of the patient

  4. hypertension: 0 if the patient doesn’t have hypertension, 1 if the patient has hypertension

  5. heart_disease: 0 if the patient doesn’t have any heart diseases, 1 if the patient has a heart disease

  6. ever_married: “No” or “Yes”

  7. work_type: “children”, “Govt_job”, “Never_worked”, “Private” or “Self-employed”

  8. Residence_type: “Rural” or “Urban”

  9. avg_glucose_level: average glucose level in blood

  10. bmi: body mass index

  11. smoking_status: “formerly smoked”, “never smoked”, “smokes” or “Unknown” (“Unknown” means that the information is unavailable for this patient)

  12. stroke: 1 if the patient had a stroke or 0 if not

2. Preparation

# data
library(dplyr)
library(ggplot2)
library(Amelia)
library(corrplot)
library(corrgram)

# model
library(caTools)
library(e1071)
library(caret)
library(ROSE)
library(Metrics)
library(class)
library(tidymodels)
library(glmnet)
# read the data
data <- read.csv("healthcare-dataset-stroke-data.csv")

# view the data
head(data)
##      id gender age hypertension heart_disease ever_married     work_type
## 1  9046   Male  67            0             1          Yes       Private
## 2 51676 Female  61            0             0          Yes Self-employed
## 3 31112   Male  80            0             1          Yes       Private
## 4 60182 Female  49            0             0          Yes       Private
## 5  1665 Female  79            1             0          Yes Self-employed
## 6 56669   Male  81            0             0          Yes       Private
##   Residence_type avg_glucose_level  bmi  smoking_status stroke
## 1          Urban            228.69 36.6 formerly smoked      1
## 2          Rural            202.21  N/A    never smoked      1
## 3          Rural            105.92 32.5    never smoked      1
## 4          Urban            171.23 34.4          smokes      1
## 5          Rural            174.12   24    never smoked      1
## 6          Urban            186.21   29 formerly smoked      1
str(data)
## 'data.frame':    5110 obs. of  12 variables:
##  $ id               : int  9046 51676 31112 60182 1665 56669 53882 10434 27419 60491 ...
##  $ gender           : chr  "Male" "Female" "Male" "Female" ...
##  $ age              : num  67 61 80 49 79 81 74 69 59 78 ...
##  $ hypertension     : int  0 0 0 0 1 0 1 0 0 0 ...
##  $ heart_disease    : int  1 0 1 0 0 0 1 0 0 0 ...
##  $ ever_married     : chr  "Yes" "Yes" "Yes" "Yes" ...
##  $ work_type        : chr  "Private" "Self-employed" "Private" "Private" ...
##  $ Residence_type   : chr  "Urban" "Rural" "Rural" "Urban" ...
##  $ avg_glucose_level: num  229 202 106 171 174 ...
##  $ bmi              : chr  "36.6" "N/A" "32.5" "34.4" ...
##  $ smoking_status   : chr  "formerly smoked" "never smoked" "never smoked" "smokes" ...
##  $ stroke           : int  1 1 1 1 1 1 1 1 1 1 ...
summary(data)
##        id           gender               age         hypertension    
##  Min.   :   67   Length:5110        Min.   : 0.08   Min.   :0.00000  
##  1st Qu.:17741   Class :character   1st Qu.:25.00   1st Qu.:0.00000  
##  Median :36932   Mode  :character   Median :45.00   Median :0.00000  
##  Mean   :36518                      Mean   :43.23   Mean   :0.09746  
##  3rd Qu.:54682                      3rd Qu.:61.00   3rd Qu.:0.00000  
##  Max.   :72940                      Max.   :82.00   Max.   :1.00000  
##  heart_disease     ever_married        work_type         Residence_type    
##  Min.   :0.00000   Length:5110        Length:5110        Length:5110       
##  1st Qu.:0.00000   Class :character   Class :character   Class :character  
##  Median :0.00000   Mode  :character   Mode  :character   Mode  :character  
##  Mean   :0.05401                                                           
##  3rd Qu.:0.00000                                                           
##  Max.   :1.00000                                                           
##  avg_glucose_level     bmi            smoking_status         stroke       
##  Min.   : 55.12    Length:5110        Length:5110        Min.   :0.00000  
##  1st Qu.: 77.25    Class :character   Class :character   1st Qu.:0.00000  
##  Median : 91.89    Mode  :character   Mode  :character   Median :0.00000  
##  Mean   :106.15                                          Mean   :0.04873  
##  3rd Qu.:114.09                                          3rd Qu.:0.00000  
##  Max.   :271.74                                          Max.   :1.00000

3. Data Cleaning

# check NA values
any(is.na(data)) # reports no missing values
## [1] FALSE
# however, bmi stores missing entries as the string "N/A"; convert them to NA
data[data == 'N/A'] <- NA
missmap(data, col=c("yellow", "black"), legend=FALSE) # the missing values are all in bmi

table(is.na(data)) # there are 201 missing values
## 
## FALSE  TRUE 
## 61119   201
# convert bmi data type to numeric
data$bmi <- as.numeric(data$bmi)
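The two steps above (recoding "N/A" as NA, then casting bmi to numeric) could also be handled when the file is read. A minimal sketch, assuming a small inline CSV in place of the real file; `toy_csv` and its three rows are made up for illustration:

```r
# Hypothetical three-row excerpt standing in for the Kaggle file
toy_csv <- "id,age,bmi,stroke
1,67,36.6,1
2,61,N/A,1
3,80,32.5,1"

# na.strings tells read.csv() to treat the literal text "N/A" as missing,
# so bmi is parsed as numeric instead of character
toy <- read.csv(text = toy_csv, na.strings = "N/A")

str(toy$bmi)          # numeric vector containing one NA
sum(is.na(toy$bmi))   # number of missing bmi values
```

With the real file, the same `na.strings` argument would be passed to the `read.csv()` call that loads "healthcare-dataset-stroke-data.csv".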

# plot bmi
hist(data$bmi)

boxplot(data$bmi)

# drop NA values
data <- na.omit(data)
any(is.na(data)) # check again: no NA values remain
## [1] FALSE
# drop the id column
data <- data[-1]
head(data)
##   gender age hypertension heart_disease ever_married     work_type
## 1   Male  67            0             1          Yes       Private
## 3   Male  80            0             1          Yes       Private
## 4 Female  49            0             0          Yes       Private
## 5 Female  79            1             0          Yes Self-employed
## 6   Male  81            0             0          Yes       Private
## 7   Male  74            1             1          Yes       Private
##   Residence_type avg_glucose_level  bmi  smoking_status stroke
## 1          Urban            228.69 36.6 formerly smoked      1
## 3          Rural            105.92 32.5    never smoked      1
## 4          Urban            171.23 34.4          smokes      1
## 5          Rural            174.12 24.0    never smoked      1
## 6          Urban            186.21 29.0 formerly smoked      1
## 7          Rural             70.09 27.4    never smoked      1
# data transformation
str(data)
## 'data.frame':    4909 obs. of  11 variables:
##  $ gender           : chr  "Male" "Male" "Female" "Female" ...
##  $ age              : num  67 80 49 79 81 74 69 78 81 61 ...
##  $ hypertension     : int  0 0 0 1 0 1 0 0 1 0 ...
##  $ heart_disease    : int  1 1 0 0 0 1 0 0 0 1 ...
##  $ ever_married     : chr  "Yes" "Yes" "Yes" "Yes" ...
##  $ work_type        : chr  "Private" "Private" "Private" "Self-employed" ...
##  $ Residence_type   : chr  "Urban" "Rural" "Urban" "Rural" ...
##  $ avg_glucose_level: num  229 106 171 174 186 ...
##  $ bmi              : num  36.6 32.5 34.4 24 29 27.4 22.8 24.2 29.7 36.8 ...
##  $ smoking_status   : chr  "formerly smoked" "never smoked" "smokes" "never smoked" ...
##  $ stroke           : int  1 1 1 1 1 1 1 1 1 1 ...
# convert character data type to factor
data <- data %>% mutate(across(where(is.character),factor))

# convert hypertension, heart_disease, stroke data type from integer to factor
data$hypertension <- as.factor(data$hypertension)
data$heart_disease <- as.factor(data$heart_disease)
data$stroke <- as.factor(data$stroke)

# binning numeric variables
# age
ggplot(data, aes(age,y=..density..)) + 
  geom_histogram(binwidth=1,
                 color="black",
                 fill="#02bcfa",
                 alpha=0.5) + 
  geom_density() + labs(title="Age Distribution")
## Warning: The dot-dot notation (`..density..`) was deprecated in ggplot2 3.4.0.
## ℹ Please use `after_stat(density)` instead.

boxplot(data$age)

summary(data$age)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##    0.08   25.00   44.00   42.87   60.00   82.00
# bin age at the quartile cut points 25, 44 and 60, with the maximum (82) as the upper bound
data$age <- cut(data$age,
                breaks = c(0, 25, 44, 60, 82), 
                labels=c('young', 'grown', 'mature', 'old'))
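To make the interval boundaries of the age binning explicit, here is a small sketch on made-up ages (the vector `ages` is illustrative and not taken from the dataset, apart from the minimum 0.08 and maximum 82 reported by `summary()`):

```r
# Toy ages, including the smallest and largest values reported by summary()
ages <- c(0.08, 25, 25.5, 44, 60, 61, 82)

# cut() uses right-closed intervals by default:
# (0,25], (25,44], (44,60], (60,82]
age_group <- cut(ages,
                 breaks = c(0, 25, 44, 60, 82),
                 labels = c("young", "grown", "mature", "old"))

data.frame(ages, age_group)
table(age_group)
```

Because the first interval is open on the left, an age of exactly 0 would become NA unless `include.lowest = TRUE` is set; the smallest age in this dataset is 0.08, so no value is lost here.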

# avg glucose level
ggplot(data, aes(avg_glucose_level, y=..density..)) + 
  geom_histogram(color="black",
                 fill="#02bcfa",
                 alpha=0.5) + 
  geom_density() + 
  labs(title="Average Glucose Level Distribution")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

boxplot(data$avg_glucose_level)

summary(data$avg_glucose_level)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   55.12   77.07   91.68  105.31  113.57  271.74
# binning avg glucose level based on the information on website:
# https://my.clevelandclinic.org/health/diagnostics/12363-blood-glucose-test#:~:text=What%20is%20a%20normal%20glucose,can%20be%20%E2%80%9Cnormal%E2%80%9D%20too.
group_glucose <- function(level){
  # character vector that receives one group label per reading
  res <- character(length(level))
  for (i in seq_along(level)){
    if (level[i] <= 70){
      res[i] <- "low"
    } else if (level[i] <= 99) {
      res[i] <- "normal"
    } else if (level[i] <= 125) {
      # (99, 125]: also covers readings between 99 and 100
      res[i] <- "prediabetes"
    } else {
      res[i] <- "diabetes"
    }
  }
  return(res)
}

# apply group_glucose function
data$avg_glucose_level <- group_glucose(data$avg_glucose_level)

# convert avg_glucose_level data type to factor
data$avg_glucose_level <- as.factor(data$avg_glucose_level)

# levels of data$avg_glucose_level are in the wrong order
levels(data$avg_glucose_level)
## [1] "diabetes"    "low"         "normal"      "prediabetes"
# reorder the levels of data$avg_glucose_level
data$avg_glucose_level <- factor(data$avg_glucose_level, levels = c("low", "normal", "prediabetes", "diabetes"))

# check again
levels(data$avg_glucose_level)
## [1] "low"         "normal"      "prediabetes" "diabetes"
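The loop, the factor conversion and the level reordering can be collapsed into one vectorized call. A minimal sketch using `cut()`; the function name `group_glucose_cut` and the toy readings are hypothetical, and the sketch assumes that every reading above 99 and up to 125 belongs to "prediabetes":

```r
# Vectorized binning: cut() returns a factor whose levels follow the
# order of the breaks, so no reordering step is needed afterwards
group_glucose_cut <- function(level) {
  cut(level,
      breaks = c(-Inf, 70, 99, 125, Inf),   # (-Inf,70], (70,99], (99,125], (125,Inf)
      labels = c("low", "normal", "prediabetes", "diabetes"))
}

# Made-up readings that sit on or near the boundaries
toy_levels <- c(55.12, 70, 85, 99.5, 110, 125, 271.74)
glucose_group <- group_glucose_cut(toy_levels)

data.frame(toy_levels, glucose_group)
levels(glucose_group)
```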
# bmi
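# map density to the y axis so the histogram and the density curve share a scale
ggplot(data, aes(bmi, y = after_stat(density))) + 
  geom_histogram(color="black",
                 fill="#02bcfa",
                 alpha=0.5) + 
  geom_density() + 
  labs(title="BMI Distribution")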
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

boxplot(data$bmi)

summary(data$bmi)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   10.30   23.50   28.10   28.89   33.10   97.60
# binning BMI based on the information on the CDC website:
# https://www.cdc.gov/healthyweight/assessing/index.html#:~:text=If%20your%20BMI%20is%20less,falls%20within%20the%20obese%20range.
group_bmi <- function(bmi){
  res <- bmi
  for (i in 1:length(bmi)){
    if (bmi[i] < 18.5){
      res[i] <- "underweight"
    } else if (bmi[i] >= 18.5 & bmi[i] <= 24.9) {
      res[i] <- "normal"
    } else if (bmi[i] >= 25.0 & bmi[i] <= 29.9) {
      res[i] <- "overweight"
    } else {
      res[i] <- "obese"
    }
  }
  return(res)
}

# apply group_bmi function
data$bmi <- group_bmi(data$bmi)

# convert bmi data type to factor
data$bmi <- as.factor(data$bmi)

# levels of data$bmi are in the wrong order
levels(data$bmi)
## [1] "normal"      "obese"       "overweight"  "underweight"
# reorder the levels of bmi
data$bmi <- factor(data$bmi, levels = c("underweight", "normal", "overweight", "obese"))
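The BMI categories are defined by their lower bounds (18.5, 25 and 30), which maps naturally onto left-closed intervals. A short sketch, assuming those cut points; `group_bmi_cut` and the toy values are illustrative names and data:

```r
# right = FALSE makes the intervals left-closed:
# [-Inf,18.5), [18.5,25), [25,30), [30,Inf)
group_bmi_cut <- function(bmi) {
  cut(bmi,
      breaks = c(-Inf, 18.5, 25, 30, Inf),
      labels = c("underweight", "normal", "overweight", "obese"),
      right = FALSE)
}

# Made-up BMI values on or near the category boundaries
toy_bmi <- c(10.3, 18.5, 24.9, 25, 29.9, 30, 97.6)
bmi_group <- group_bmi_cut(toy_bmi)

data.frame(toy_bmi, bmi_group)
```

Unlike bounds written as 24.9 and 29.9, this form also assigns a category to values such as 24.95 that carry more than one decimal place.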

# check the structure
str(data)
## 'data.frame':    4909 obs. of  11 variables:
##  $ gender           : Factor w/ 3 levels "Female","Male",..: 2 2 1 1 2 2 1 1 1 1 ...
##  $ age              : Factor w/ 4 levels "young","grown",..: 4 4 3 4 4 4 4 4 4 4 ...
##  $ hypertension     : Factor w/ 2 levels "0","1": 1 1 1 2 1 2 1 1 2 1 ...
##  $ heart_disease    : Factor w/ 2 levels "0","1": 2 2 1 1 1 2 1 1 1 2 ...
##  $ ever_married     : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 2 1 2 2 2 ...
##  $ work_type        : Factor w/ 5 levels "children","Govt_job",..: 4 4 4 5 4 4 4 4 4 2 ...
##  $ Residence_type   : Factor w/ 2 levels "Rural","Urban": 2 1 2 1 2 1 2 2 1 1 ...
##  $ avg_glucose_level: Factor w/ 4 levels "low","normal",..: 4 3 4 4 4 2 2 1 2 3 ...
##  $ bmi              : Factor w/ 4 levels "underweight",..: 4 4 4 2 3 3 2 2 3 4 ...
##  $ smoking_status   : Factor w/ 4 levels "formerly smoked",..: 1 2 3 2 1 2 2 4 2 3 ...
##  $ stroke           : Factor w/ 2 levels "0","1": 2 2 2 2 2 2 2 2 2 2 ...

4. Exploratory Data Analysis

# check the correlation
# convert factor variables to numeric variables
data_num <- data %>% mutate(across(where(is.factor),as.numeric))
str(data_num)
## 'data.frame':    4909 obs. of  11 variables:
##  $ gender           : num  2 2 1 1 2 2 1 1 1 1 ...
##  $ age              : num  4 4 3 4 4 4 4 4 4 4 ...
##  $ hypertension     : num  1 1 1 2 1 2 1 1 2 1 ...
##  $ heart_disease    : num  2 2 1 1 1 2 1 1 1 2 ...
##  $ ever_married     : num  2 2 2 2 2 2 1 2 2 2 ...
##  $ work_type        : num  4 4 4 5 4 4 4 4 4 2 ...
##  $ Residence_type   : num  2 1 2 1 2 1 2 2 1 1 ...
##  $ avg_glucose_level: num  4 3 4 4 4 2 2 1 2 3 ...
##  $ bmi              : num  4 4 4 2 3 3 2 2 3 4 ...
##  $ smoking_status   : num  1 2 3 2 1 2 2 4 2 3 ...
##  $ stroke           : num  2 2 2 2 2 2 2 2 2 2 ...
# correlation and corrplot
(cor <- cor(data_num))
##                         gender         age hypertension heart_disease
## gender             1.000000000 -0.01692503  0.021578286   0.082711652
## age               -0.016925028  1.00000000  0.270583325   0.251216361
## hypertension       0.021578286  0.27058332  1.000000000   0.115990991
## heart_disease      0.082711652  0.25121636  0.115990991   1.000000000
## ever_married      -0.037236693  0.66476920  0.162406260   0.111245121
## work_type         -0.072538268  0.45266958  0.124654706   0.092144819
## Residence_type    -0.005013763  0.01457606 -0.001074146  -0.002361744
## avg_glucose_level  0.040858382  0.14965716  0.124198217   0.106454582
## bmi                0.020621827  0.38468598  0.159444545   0.085086762
## smoking_status     0.038252248 -0.34128144 -0.132831660  -0.071396924
## stroke             0.006757363  0.21921811  0.142514606   0.137937788
##                   ever_married     work_type Residence_type avg_glucose_level
## gender            -0.037236693 -0.0725382684  -0.0050137626       0.040858382
## age                0.664769200  0.4526695779   0.0145760598       0.149657159
## hypertension       0.162406260  0.1246547061  -0.0010741462       0.124198217
## heart_disease      0.111245121  0.0921448190  -0.0023617439       0.106454582
## ever_married       1.000000000  0.4259143556   0.0049891711       0.096121554
## work_type          0.425914356  1.0000000000  -0.0008827106       0.054018824
## Residence_type     0.004989171 -0.0008827106   1.0000000000      -0.007651788
## avg_glucose_level  0.096121554  0.0540188244  -0.0076517877       1.000000000
## bmi                0.404389675  0.4020328198  -0.0074476403       0.112484272
## smoking_status    -0.310702330 -0.3444032458   0.0027191093      -0.065479200
## stroke             0.105089144  0.0797450467   0.0060314265       0.094897697
##                           bmi smoking_status       stroke
## gender             0.02062183    0.038252248  0.006757363
## age                0.38468598   -0.341281444  0.219218106
## hypertension       0.15944454   -0.132831660  0.142514606
## heart_disease      0.08508676   -0.071396924  0.137937788
## ever_married       0.40438967   -0.310702330  0.105089144
## work_type          0.40203282   -0.344403246  0.079745047
## Residence_type    -0.00744764    0.002719109  0.006031426
## avg_glucose_level  0.11248427   -0.065479200  0.094897697
## bmi                1.00000000   -0.282826306  0.064070432
## smoking_status    -0.28282631    1.000000000 -0.075919784
## stroke             0.06407043   -0.075919784  1.000000000
# corrplot
corrplot(cor, method = "color")

# corrgram
corrgram(data_num, order = TRUE, 
         lower.panel = panel.shade,
         upper.panel = panel.pie,
         text.panel = panel.txt)
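The correlations above are computed on the integer codes of the factors, so for unordered attributes such as work_type or smoking_status they depend on the arbitrary order of the levels. One common alternative for two categorical variables is Cramér's V. The following is a rough sketch on made-up data, not part of the original analysis; `cramers_v` and `toy` are hypothetical names:

```r
# Cramér's V: sqrt(chi-squared / (n * (min(rows, cols) - 1))), in [0, 1]
cramers_v <- function(x, y) {
  tab <- table(x, y)
  # Pearson chi-squared statistic without continuity correction
  chi2 <- suppressWarnings(chisq.test(tab, correct = FALSE)$statistic)
  n <- sum(tab)
  k <- min(nrow(tab), ncol(tab)) - 1
  as.numeric(sqrt(chi2 / (n * k)))
}

# Made-up data: two unrelated categorical variables
set.seed(1)
toy <- data.frame(
  smoking = factor(sample(c("never", "former", "smokes"), 200, replace = TRUE)),
  stroke  = factor(sample(c(0, 1), 200, replace = TRUE))
)

v_unrelated <- cramers_v(toy$smoking, toy$stroke)  # small, since the columns are independent
v_identical <- cramers_v(toy$stroke, toy$stroke)   # a variable is perfectly associated with itself
c(v_unrelated, v_identical)
```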

# Age Group Distribution
ggplot(data, aes(age)) + 
  geom_bar(aes(fill = age)) +
  scale_fill_brewer(palette = "Set2") +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Age", y = "Count", title ="Age Group Distribution")

# age & stroke
ggplot(data, aes(age)) + 
  geom_bar(aes(fill = stroke)) +
  scale_fill_brewer(palette = "Set2") +
  facet_grid(cols = vars(stroke)) +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Age", y = "Count", title ="Age Group Distribution with Class Label")

# normalize the height
ggplot(data, aes(age)) + 
  geom_bar(aes(fill = stroke),
           position = "fill",
           alpha=0.8) +
  labs(x = "Age", y = "Scaled Count", title ="Age Distribution with Normalized Height")
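The bars drawn with `position = "fill"` show, for each group, the share of patients in each stroke class. The same numbers can be read from a row-wise proportion table. A minimal sketch on a made-up data frame (`toy` is illustrative only):

```r
# Made-up data: 3 young patients without stroke, 5 old patients of whom 2 had a stroke
toy <- data.frame(
  age    = factor(c("young", "young", "young", "old", "old", "old", "old", "old"),
                  levels = c("young", "old")),
  stroke = factor(c(0, 0, 0, 0, 0, 1, 1, 0))
)

# counts per age group and stroke class
counts <- table(toy$age, toy$stroke)

# margin = 1 normalizes each row, i.e. each age group sums to 1
props <- prop.table(counts, margin = 1)
round(props, 2)
```

Applied to the real data, `prop.table(table(data$age, data$stroke), margin = 1)` would give the heights of the stacked segments in the plot.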

# Gender Distribution
ggplot(data, aes(gender)) + 
  geom_bar(aes(fill = gender)) +
  scale_fill_brewer(palette = "Set2") +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Gender", y = "Count", title ="Gender Distribution")

# gender & stroke
ggplot(data, aes(gender)) + 
  geom_bar(aes(fill = stroke)) +
  scale_fill_brewer(palette = "Set2") +
  facet_grid(cols = vars(stroke)) +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Gender", y = "Count", title ="Gender Distribution with Class Label")

# normalize the height
ggplot(data, aes(gender)) + 
  geom_bar(aes(fill = stroke),
           position = "fill",
           alpha=0.8) +
  labs(x = "Gender", y = "Scaled Count", title ="Gender Distribution with Normalized Height")

# Hypertension Distribution
ggplot(data, aes(hypertension)) + 
  geom_bar(aes(fill = hypertension)) +
  scale_fill_brewer(palette = "Set2") +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Hypertension", y = "Count", title ="Hypertension Distribution")

# hypertension & stroke
ggplot(data, aes(hypertension)) + 
  geom_bar(aes(fill = stroke)) +
  scale_fill_brewer(palette = "Set2") +
  facet_grid(cols = vars(stroke)) +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Hypertension", y = "Count", title ="Hypertension Distribution with Class Label")

# normalize the height
ggplot(data, aes(hypertension)) + 
  geom_bar(aes(fill = stroke),
           position = "fill",
           alpha=0.8) +
  labs(x = "Hypertension", y = "Scaled Count", title ="Hypertension Distribution with Normalized Height")

# Heart Disease Distribution
ggplot(data, aes(heart_disease)) + 
  geom_bar(aes(fill = heart_disease)) +
  scale_fill_brewer(palette = "Set2") +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Heart Disease", y = "Count", title ="Heart Disease Distribution")

# heart_disease & stroke
ggplot(data, aes(heart_disease)) + 
  geom_bar(aes(fill = stroke)) +
  scale_fill_brewer(palette = "Set2") +
  facet_grid(cols = vars(stroke)) +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Heart Disease", y = "Count", title ="Heart Disease Distribution with Class Label")

# normalize the height
ggplot(data, aes(heart_disease)) + 
  geom_bar(aes(fill = stroke),
           position = "fill",
           alpha=0.8) +
  labs(x = "Heart Disease", y = "Scaled Count", title ="Heart Disease Distribution with Normalized Height")

# Marital Status
ggplot(data, aes(ever_married)) + 
  geom_bar(aes(fill = ever_married)) +
  scale_fill_brewer(palette = "Set2") +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Marital Status", y = "Count", title ="Marital Status")

# ever_married & stroke
ggplot(data, aes(ever_married)) + 
  geom_bar(aes(fill = stroke)) +
  scale_fill_brewer(palette = "Set2") +
  facet_grid(cols = vars(stroke)) +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Marital Status", y = "Count", title ="Marital Status with Class Label")

# normalize the height
ggplot(data, aes(ever_married)) + 
  geom_bar(aes(fill = stroke),
           position = "fill",
           alpha=0.8) +
  labs(x = "Marital Status", y = "Scaled Count", title ="Marital Status with Normalized Height")

# Distribution of Work Type
ggplot(data, aes(work_type)) + 
  geom_bar(aes(fill = work_type)) +
  scale_fill_brewer(palette = "Set2") +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Work Type", y = "Count", title ="Distribution of Work Type")

# work_type & stroke
ggplot(data, aes(work_type)) + 
  geom_bar(aes(fill = stroke)) +
  scale_fill_brewer(palette = "Set2") +
  facet_grid(cols = vars(stroke)) +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Work Type", y = "Count", title ="Distribution of Work Type with Class Label")

# normalize the height
ggplot(data, aes(work_type)) + 
  geom_bar(aes(fill = stroke),
           position = "fill",
           alpha=0.8) +
  labs(x = "Work Type", y = "Scaled Count", title ="Distribution of Work Type with Normalized Height")

# Distribution of Residence Type
ggplot(data, aes(Residence_type)) + 
  geom_bar(aes(fill = Residence_type)) +
  scale_fill_brewer(palette = "Set2") +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Residence Type", y = "Count", title ="Distribution of Residence Type")

# Residence_type & stroke
ggplot(data, aes(Residence_type)) + 
  geom_bar(aes(fill = stroke)) +
  scale_fill_brewer(palette = "Set2") +
  facet_grid(cols = vars(stroke)) +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Residence Type", y = "Count", title ="Distribution of Residence Type with Class Label")

# normalize the height
ggplot(data, aes(Residence_type)) + 
  geom_bar(aes(fill = stroke),
           position = "fill",
           alpha=0.8) +
  labs(x = "Residence Type", y = "Scaled Count", title ="Distribution of Residence Type with Normalized Height")

# Group of Average Glucose Level
ggplot(data, aes(avg_glucose_level)) + 
  geom_bar(aes(fill = avg_glucose_level)) +
  scale_fill_brewer(palette = "Set2") +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Average Glucose Level", y = "Count", title ="Group of Average Glucose Level")

# avg_glucose_level & stroke
ggplot(data, aes(avg_glucose_level)) + 
  geom_bar(aes(fill = stroke)) +
  scale_fill_brewer(palette = "Set2") +
  facet_grid(cols = vars(stroke)) +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Average Glucose Level", y = "Count", title ="Distribution of Average Glucose Level Group with Class Label")

# normalize the height
ggplot(data, aes(avg_glucose_level)) + 
  geom_bar(aes(fill = stroke),
           position = "fill",
           alpha=0.8) +
  labs(x = "Average Glucose Level", y = "Scaled Count", title ="Group of Average Glucose Level with Normalized Height")

# Distribution of BMI Group
ggplot(data, aes(bmi)) + 
  geom_bar(aes(fill = bmi)) +
  scale_fill_brewer(palette = "Set2") +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "BMI Group", y = "Count", title ="Distribution of BMI Group")

# bmi & stroke
ggplot(data, aes(bmi)) + 
  geom_bar(aes(fill = stroke)) +
  scale_fill_brewer(palette = "Set2") +
  facet_grid(cols = vars(stroke)) +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "BMI Group", y = "Count", title ="Distribution of BMI Group with Class Label")

# normalize the height
ggplot(data, aes(bmi)) + 
  geom_bar(aes(fill = stroke),
           position = "fill",
           alpha=0.8) +
  labs(x = "BMI Group", y = "Scaled Count", title ="BMI Group with Normalized Height")

# Distribution of Smoking Status
ggplot(data, aes(smoking_status)) + 
  geom_bar(aes(fill = smoking_status)) +
  scale_fill_brewer(palette = "Set2") +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Smoking Status", y = "Count", title ="Distribution of Smoking Status")

# smoking_status & stroke
ggplot(data, aes(smoking_status)) + 
  geom_bar(aes(fill = stroke)) +
  scale_fill_brewer(palette = "Set2") +
  facet_grid(cols = vars(stroke)) +
  geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
  labs(x = "Smoking Status", y = "Count", title ="Distribution of Smoking Status with Class Label")

# normalize the height
ggplot(data, aes(smoking_status)) + 
  geom_bar(aes(fill = stroke),
           position = "fill",
           alpha=0.8) +
  labs(x = "Smoking Status", y = "Scaled Count", title ="Distribution of Smoking Status with Normalized Height")

# age, avg_glucose_level
ggplot(data, aes(age)) +
  geom_bar(alpha = 0.8, aes(fill = avg_glucose_level)) +
  facet_grid(rows = vars(avg_glucose_level)) +
  scale_fill_brewer(palette = "Reds") +
  geom_text(stat='count', aes(label=..count..)) +
  labs(x = "Age Group", y = "Count", title ="Distribution of Average Glucose Levels in Different Age Groups")

# age, hypertension
ggplot(data, aes(age)) +
  geom_bar(alpha = 0.8, aes(fill = hypertension)) +
  facet_grid(rows = vars(hypertension)) +
  scale_fill_manual(values = c("skyblue", "royalblue", "blue", "navy")) +
  geom_text(stat='count', aes(label=..count..)) +
  labs(x = "Age Group", y = "Count", title ="Hypertension Status in Different Age Groups")

# age, heart_disease
ggplot(data, aes(age)) +
  geom_bar(alpha = 0.8, aes(fill = heart_disease)) +
  facet_grid(rows = vars(heart_disease)) +
  scale_fill_manual(values = c("skyblue", "royalblue", "blue", "navy")) +
  geom_text(stat='count', aes(label=..count..)) +
  labs(x = "Age Group", y = "Count", title ="Heart Disease Status in Different Age Groups")

# age, bmi
ggplot(data, aes(age)) +
  geom_bar(alpha = 0.8, aes(fill = bmi)) +
  facet_grid(rows = vars(bmi)) +
  scale_fill_brewer(palette = "Reds") +
  geom_text(stat='count', aes(label=..count..)) +
  labs(x = "Age Group", y = "Count", title ="BMI Group in Different Age Groups")

# avg_glucose_level, bmi, stroke
ggplot(data, aes(bmi, avg_glucose_level)) +
  geom_jitter(alpha = 0.6, aes(color = stroke), size =1) +
  facet_grid(rows = vars(stroke)) +
  labs(x = "BMI Group", y = "Group of Average Glucose Level", title ="Distribution of BMI & Average Glucose Level with Class Label")

# hypertension, avg_glucose_level, bmi
ggplot(data, aes(bmi, avg_glucose_level)) +
  geom_jitter(alpha = 0.6, aes(color = hypertension), size =1) +
  facet_grid(rows = vars(hypertension)) +
  labs(x = "BMI Group", y = "Group of Average Glucose Level", title ="Distribution of BMI & Average Glucose Level with Different Hypertension Status")


5. Model 1: Support Vector Machine

# drop attributes that are nearly uncorrelated with stroke: gender, Residence_type
data_drop <- select(data, -gender, -Residence_type)

# Random Over-Sampling
# move the class label (stroke) to the first column by reversing the column order
data_md <- data_drop[c(9:1)]

# over sampling data
data_os <- ovun.sample(stroke~., data=data_md, method = "over", p = 0.5, seed = 1)

# check the data after over sampling
str(data_os)
## List of 3
##  $ Call  : language ovun.sample(formula = stroke ~ ., data = data_md, method = "over", p = 0.5,      seed = 1)
##  $ method: chr "over"
##  $ data  :'data.frame':  9367 obs. of  9 variables:
##   ..$ stroke           : Factor w/ 2 levels "0","1": 1 1 1 1 1 1 1 1 1 1 ...
##   ..$ smoking_status   : Factor w/ 4 levels "formerly smoked",..: 4 2 4 1 4 4 1 2 3 2 ...
##   ..$ bmi              : Factor w/ 4 levels "underweight",..: 1 4 1 4 2 4 1 3 4 4 ...
##   ..$ avg_glucose_level: Factor w/ 4 levels "low","normal",..: 2 2 3 1 4 4 2 4 2 4 ...
##   ..$ work_type        : Factor w/ 5 levels "children","Govt_job",..: 1 4 4 4 3 4 4 5 4 5 ...
##   ..$ ever_married     : Factor w/ 2 levels "No","Yes": 1 2 1 2 1 2 2 2 2 2 ...
##   ..$ heart_disease    : Factor w/ 2 levels "0","1": 1 1 1 1 1 1 1 2 1 1 ...
##   ..$ hypertension     : Factor w/ 2 levels "0","1": 1 2 1 1 1 1 1 1 1 2 ...
##   ..$ age              : Factor w/ 4 levels "young","grown",..: 1 3 1 4 1 3 3 4 2 4 ...
##  - attr(*, "class")= chr "ovun.sample"
summary(data_os)
## 
## Call: 
## ovun.sample(formula = stroke ~ ., data = data_md, method = "over", 
##     p = 0.5, seed = 1)
## 
## Summary of data balanced by oversampling 
## 
##  stroke           smoking_status          bmi         avg_glucose_level
##  0:4700   formerly smoked:1981   underweight: 348   low        :1159   
##  1:4667   never smoked   :3649   normal     :2012   normal     :3695   
##           smokes         :1580   overweight :2931   prediabetes:1602   
##           Unknown        :2157   obese      :4076   diabetes   :2911   
##                                                                        
##          work_type    ever_married heart_disease hypertension     age      
##  children     : 687   No :2184     0:8259        0:7602       young :1280  
##  Govt_job     :1184   Yes:7183     1:1108        1:1765       grown :1360  
##  Never_worked :  22                                           mature:2380  
##  Private      :5563                                           old   :4347  
##  Self-employed:1911
table(data_os$data$stroke)
## 
##    0    1 
## 4700 4667
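Random oversampling itself is a simple idea: rows of the minority class are drawn again with replacement until the classes are roughly balanced. The following base R sketch illustrates that idea on made-up data; it is not the algorithm used inside `ovun.sample()`, and it balances the classes exactly rather than approximately:

```r
# Made-up imbalanced data: 95 rows of class 0, 5 rows of class 1
set.seed(1)
toy <- data.frame(x = rnorm(100),
                  stroke = factor(rep(c(0, 1), times = c(95, 5))))

majority <- which(toy$stroke == "0")
minority <- which(toy$stroke == "1")

# draw additional minority rows (with replacement) to match the majority count
extra <- sample(minority, size = length(majority) - length(minority), replace = TRUE)
toy_os <- toy[c(majority, minority, extra), ]

table(toy$stroke)      # before oversampling
table(toy_os$stroke)   # after oversampling
```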
# move the class label (stroke) back to the last column
data_os <- data_os$data[c(9:1)]

# Train and Test Split
set.seed(101)

split <- sample.split(data_os$stroke, SplitRatio = 0.7)
train <- subset(data_os, split == TRUE)
test <- subset(data_os, split == FALSE)
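Because oversampling duplicates minority rows, a split made afterwards can place copies of the same patient in both the training and the test set. One common way to avoid this is to split first and oversample only the training part. A minimal base R sketch on made-up data, offered as an alternative ordering rather than as the procedure used in this report (`toy`, `train_toy`, `test_toy` and `train_os` are hypothetical names):

```r
set.seed(101)

# Made-up data: 180 patients without stroke, 20 with stroke
toy <- data.frame(id = 1:200,
                  stroke = factor(rep(c(0, 1), times = c(180, 20))))

# stratified 70/30 split: sample 70% of the row indices within each class
train_idx <- unlist(lapply(split(seq_len(nrow(toy)), toy$stroke),
                           function(idx) sample(idx, size = round(0.7 * length(idx)))),
                    use.names = FALSE)
train_toy <- toy[train_idx, ]
test_toy  <- toy[-train_idx, ]

# oversample the minority class in the training set only
min_idx  <- which(train_toy$stroke == "1")
n_extra  <- sum(train_toy$stroke == "0") - length(min_idx)
train_os <- rbind(train_toy,
                  train_toy[sample(min_idx, n_extra, replace = TRUE), ])

table(train_os$stroke)                       # balanced training classes
table(test_toy$stroke)                       # test set keeps the original imbalance
length(intersect(train_os$id, test_toy$id))  # no patient appears in both sets
```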
# SVM
model_svm <- svm(stroke ~ ., data = train)

# check the model
summary(model_svm)
## 
## Call:
## svm(formula = stroke ~ ., data = train)
## 
## 
## Parameters:
##    SVM-Type:  C-classification 
##  SVM-Kernel:  radial 
##        cost:  1 
## 
## Number of Support Vectors:  3425
## 
##  ( 1692 1733 )
## 
## 
## Number of Classes:  2 
## 
## Levels: 
##  0 1
# use the model on test data to predict our label (stroke)
pred_svm <- predict(model_svm, test[1:8])

# check the model performance
confusionMatrix(pred_svm, 
                factor(test$stroke), 
                mode = "everything", 
                positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0  975  189
##          1  435 1211
##                                           
##                Accuracy : 0.7779          
##                  95% CI : (0.7621, 0.7932)
##     No Information Rate : 0.5018          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.5561          
##                                           
##  Mcnemar's Test P-Value : < 2.2e-16       
##                                           
##             Sensitivity : 0.8650          
##             Specificity : 0.6915          
##          Pos Pred Value : 0.7357          
##          Neg Pred Value : 0.8376          
##               Precision : 0.7357          
##                  Recall : 0.8650          
##                      F1 : 0.7951          
##              Prevalence : 0.4982          
##          Detection Rate : 0.4310          
##    Detection Prevalence : 0.5858          
##       Balanced Accuracy : 0.7782          
##                                           
##        'Positive' Class : 1               
## 
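The statistics printed by confusionMatrix() can be checked by hand from the four cell counts. The sketch below copies the counts from the matrix above (rows are predictions, columns are the reference, positive class 1) and recomputes the main metrics; the variable names are chosen for this example only.

```r
# Counts copied from the default-SVM confusion matrix above
tn <- 975; fn <- 189; fp <- 435; tp <- 1211

sensitivity <- tp / (tp + fn)                 # recall for the stroke class
specificity <- tn / (tn + fp)
precision   <- tp / (tp + fp)                 # positive predictive value
accuracy    <- (tp + tn) / (tp + tn + fp + fn)
f1          <- 2 * precision * sensitivity / (precision + sensitivity)
balanced    <- (sensitivity + specificity) / 2

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity, precision = precision,
        f1 = f1, balanced_accuracy = balanced), 4)
```

The rounded values should agree with the Accuracy, Sensitivity, Specificity, Precision, F1 and Balanced Accuracy lines in the output above.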
# Parameter tuning
# sampling method: 10-fold cross-validation
# The grid search takes a long time to run, so it is commented out here

# tune.results <- tune(svm, train.x = stroke ~ .,
#                      data = train,
#                      kernel = 'radial',
#                      ranges = list(cost = c(1,10), 
#                                    gamma = c(0.1,1)))
# tune.results

# set cost = 10, gamma = 1
model_svm <- svm(stroke ~ ., data = train,
                 kernel = 'radial',
                 cost = 10,
                 gamma = 1)

# apply the tuned SVM model on test data to predict class label (stroke)
pred_svm <- predict(model_svm,test[1:8])
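To give a sense of why the commented-out grid search is slow, the sketch below rebuilds the same parameter grid in base R and counts the model fits it implies. It assumes one SVM fit per grid row per fold, as the 10-fold comment above suggests; `grid`, `n_folds` and `n_fits` are names introduced for this example.

```r
# Hyperparameter grid from the commented-out tune() call above
grid <- expand.grid(cost = c(1, 10), gamma = c(0.1, 1))
n_folds <- 10  # assumption: 10-fold cross-validation, as stated in the comment

# each grid row is fitted once per fold
n_fits <- nrow(grid) * n_folds
grid
n_fits
```

With roughly 6,500 training rows and thousands of support vectors per fit, even this small grid of four combinations requires dozens of radial-kernel fits.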

6. Model 2: K-nearest Neighbors

# KNN
# convert factor columns to their numeric level codes
data_os_num <- data_os %>% mutate(across(where(is.factor), as.numeric))
# turn stroke back into a factor (its levels are now the codes "1" and "2")
data_os_num$stroke <- factor(data_os_num$stroke)
str(data_os_num)
## 'data.frame':    9367 obs. of  9 variables:
##  $ age              : num  1 3 1 4 1 3 3 4 2 4 ...
##  $ hypertension     : num  1 2 1 1 1 1 1 1 1 2 ...
##  $ heart_disease    : num  1 1 1 1 1 1 1 2 1 1 ...
##  $ ever_married     : num  1 2 1 2 1 2 2 2 2 2 ...
##  $ work_type        : num  1 4 4 4 3 4 4 5 4 5 ...
##  $ avg_glucose_level: num  2 2 3 1 4 4 2 4 2 4 ...
##  $ bmi              : num  1 4 1 4 2 4 1 3 4 4 ...
##  $ smoking_status   : num  4 2 4 1 4 4 1 2 3 2 ...
##  $ stroke           : Factor w/ 2 levels "1","2": 1 1 1 1 1 1 1 1 1 1 ...
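The str() output shows stroke with levels "1" and "2" rather than 0 and 1. This happens because as.numeric() on a factor returns the internal level codes, not the original labels. A small self-contained sketch of the difference, using a toy vector rather than the project data:

```r
stroke <- factor(c(0, 1, 1, 0))

codes  <- as.numeric(stroke)                # internal level codes: 1 2 2 1
labels <- as.numeric(as.character(stroke))  # original labels:      0 1 1 0

codes
labels
```

If the original 0/1 coding is needed downstream, going through as.character() first is one way to preserve it.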
# standardize the dataset except class label (stroke)
data_std <- scale(data_os_num[1:8])
head(data_std)
##           age hypertension heart_disease ever_married  work_type
## 1 -1.90428262   -0.4818205    -0.3662545    -1.813441 -2.3945821
## 2 -0.04243665    2.0752403    -0.3662545     0.551379  0.2379499
## 3 -1.90428262   -0.4818205    -0.3662545    -1.813441  0.2379499
## 4  0.88848633   -0.4818205    -0.3662545     0.551379  0.2379499
## 5 -1.90428262   -0.4818205    -0.3662545    -1.813441 -0.6395607
## 6 -0.04243665   -0.4818205    -0.3662545     0.551379  0.2379499
##   avg_glucose_level        bmi smoking_status
## 1        -0.6404483 -2.4341304      1.4905203
## 2        -0.6404483  0.9685906     -0.3935231
## 3         0.3171063 -2.4341304      1.4905203
## 4        -1.5980030  0.9685906     -1.3355448
## 5         1.2746609 -1.2998901      1.4905203
## 6         1.2746609  0.9685906      1.4905203
# check variance
var(data_std[,8])
## [1] 1
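The code above standardizes all rows before the train/test split, so the test rows contribute to the column means and standard deviations. One commonly used alternative is to estimate the scaling parameters on the training rows only and reuse them for the test rows. The sketch below shows this on synthetic data; the matrix `x` and the index `train_idx` are stand-ins invented for this example, not objects from the analysis.

```r
set.seed(101)
# synthetic stand-in for the eight predictor columns
x <- matrix(rnorm(200, mean = 5, sd = 2), ncol = 2,
            dimnames = list(NULL, c("f1", "f2")))
train_idx <- sample(nrow(x), 70)

# estimate center and scale on the training rows only
train_std <- scale(x[train_idx, ])

# reuse the training parameters for the test rows
test_std <- scale(x[-train_idx, ],
                  center = attr(train_std, "scaled:center"),
                  scale  = attr(train_std, "scaled:scale"))

colMeans(train_std)  # approximately 0 by construction
colMeans(test_std)   # generally close to, but not exactly, 0
```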
# add label column (stroke) back
data_knn <- cbind(data_std, data_os_num[9])
head(data_knn)
##           age hypertension heart_disease ever_married  work_type
## 1 -1.90428262   -0.4818205    -0.3662545    -1.813441 -2.3945821
## 2 -0.04243665    2.0752403    -0.3662545     0.551379  0.2379499
## 3 -1.90428262   -0.4818205    -0.3662545    -1.813441  0.2379499
## 4  0.88848633   -0.4818205    -0.3662545     0.551379  0.2379499
## 5 -1.90428262   -0.4818205    -0.3662545    -1.813441 -0.6395607
## 6 -0.04243665   -0.4818205    -0.3662545     0.551379  0.2379499
##   avg_glucose_level        bmi smoking_status stroke
## 1        -0.6404483 -2.4341304      1.4905203      1
## 2        -0.6404483  0.9685906     -0.3935231      1
## 3         0.3171063 -2.4341304      1.4905203      1
## 4        -1.5980030  0.9685906     -1.3355448      1
## 5         1.2746609 -1.2998901      1.4905203      1
## 6         1.2746609  0.9685906      1.4905203      1
# train and test split for KNN model
set.seed(101)
split_knn <- sample.split(data_knn$stroke, SplitRatio = 0.7)
train_knn <- subset(data_knn, split_knn == TRUE)
test_knn <- subset(data_knn, split_knn == FALSE)

# build KNN model
pred_knn <- knn(train_knn[1:8],
                test_knn[1:8],
                train_knn$stroke,
                k = 1)

# check the error rate
er_knn <- mean(test_knn$stroke != pred_knn)
er_knn
## [1] 0.09644128
# Parameter tuning: record the test error rate for k = 1, ..., 10
for (i in 1:10){
  set.seed(101)
  pred_knn <- knn(train_knn[1:8],
                  test_knn[1:8],
                  train_knn$stroke,
                  k=i)
  er_knn[i] <- mean(test_knn$stroke != pred_knn)
}

# elbow method
k <- 1:10
(df <- data.frame(er_knn, k))
##        er_knn  k
## 1  0.09644128  1
## 2  0.10640569  2
## 3  0.11886121  3
## 4  0.12669039  4
## 5  0.13594306  5
## 6  0.14555160  6
## 7  0.15053381  7
## 8  0.15836299  8
## 9  0.16298932  9
## 10 0.16548043 10
ggplot(df, aes(k, er_knn)) + 
  geom_point() + 
  geom_line(lty="dotted",
            color="blue")

# refit with k = 1, which has the lowest error rate in the table above
pred_knn <- knn(train_knn[1:8],
                test_knn[1:8],
                train_knn$stroke,
                k = 1)
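The loop above picks k by comparing error rates on the test set. A possible alternative, sketched here on synthetic data, is to choose k with leave-one-out cross-validation on the training rows only, so that the test set stays untouched until the final evaluation. The sketch uses knn.cv() from the class package (the same package that provides knn()); `x`, `y`, `ks`, `cv_err` and `best_k` are names made up for this example.

```r
library(class)  # provides knn() and knn.cv()

set.seed(101)
# synthetic two-class stand-in for train_knn[1:8] and train_knn$stroke
n <- 200
x <- rbind(matrix(rnorm(n * 2, mean = 0), ncol = 2),
           matrix(rnorm(n * 2, mean = 1.5), ncol = 2))
y <- factor(rep(c(0, 1), each = n))

ks <- 1:10
# leave-one-out error rate on the training rows for each candidate k
cv_err <- sapply(ks, function(k) mean(knn.cv(x, y, k = k) != y))
best_k <- ks[which.min(cv_err)]

data.frame(k = ks, cv_err = cv_err)
best_k
```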

7. Model 3: Logistic Regression

# Logistic Regression 
model_log <- glm(formula = stroke ~ .,
                 family = binomial(logit), 
                 data = train)
summary(model_log)
## 
## Call:
## glm(formula = stroke ~ ., family = binomial(logit), data = train)
## 
## Deviance Residuals: 
##      Min        1Q    Median        3Q       Max  
## -2.36507  -0.64837  -0.00011   0.82444   2.61605  
## 
## Coefficients:
##                               Estimate Std. Error z value Pr(>|z|)    
## (Intercept)                   -4.96714    0.47134 -10.538  < 2e-16 ***
## agegrown                      15.64066  201.13166   0.078 0.938016    
## agemature                     17.40326  201.13165   0.087 0.931048    
## ageold                        18.36437  201.13165   0.091 0.927250    
## hypertension1                  0.85271    0.08434  10.110  < 2e-16 ***
## heart_disease1                 0.56515    0.10699   5.282 1.27e-07 ***
## ever_marriedYes               -0.13441    0.10377  -1.295 0.195253    
## work_typeGovt_job            -14.53861  201.13186  -0.072 0.942376    
## work_typeNever_worked        -14.15892  945.74134  -0.015 0.988055    
## work_typePrivate             -14.25017  201.13184  -0.071 0.943517    
## work_typeSelf-employed       -14.39587  201.13185  -0.072 0.942941    
## avg_glucose_levelnormal        0.20099    0.10210   1.969 0.048999 *  
## avg_glucose_levelprediabetes   0.20382    0.11646   1.750 0.080104 .  
## avg_glucose_leveldiabetes      0.62338    0.10666   5.845 5.08e-09 ***
## bminormal                      1.49620    0.41912   3.570 0.000357 ***
## bmioverweight                  1.55694    0.41878   3.718 0.000201 ***
## bmiobese                       1.54556    0.41865   3.692 0.000223 ***
## smoking_statusnever smoked    -0.15925    0.08093  -1.968 0.049086 *  
## smoking_statussmokes           0.22365    0.09910   2.257 0.024020 *  
## smoking_statusUnknown          0.03292    0.09847   0.334 0.738190    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 9089.9  on 6556  degrees of freedom
## Residual deviance: 6319.6  on 6537  degrees of freedom
## AIC: 6359.6
## 
## Number of Fisher Scoring iterations: 16
# apply the model on test dataset
pred_log <- predict(model_log,
                    newdata = test,
                    type='response')

# classify as stroke (1) when the predicted probability exceeds 0.5
res_log <- ifelse(pred_log > 0.5, 1, 0)
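The coefficients in the summary above are on the log-odds scale, which is hard to read directly. Exponentiating them gives odds ratios. The sketch below does this for a few of the statistically significant estimates, copied by hand from the output above; with the fitted object available, exp(coef(model_log)) would give the same quantities for every term.

```r
# Log-odds estimates copied from the summary(model_log) output above
log_odds <- c(hypertension1             = 0.85271,
              heart_disease1            = 0.56515,
              avg_glucose_leveldiabetes = 0.62338,
              smoking_statussmokes      = 0.22365)

# odds ratio relative to each variable's reference level
odds_ratio <- exp(log_odds)
round(odds_ratio, 2)
```

For example, the hypertension estimate corresponds to an odds ratio of about 2.35, meaning the fitted odds of stroke are a little more than twice as high for patients with hypertension, holding the other variables fixed.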

8. Model Evaluation

# Evaluate the performance of models

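# 1. SVM model
# ROC curve
# note: assuming this is ROSE::roc.curve(response, predicted), the predictions are passed as the response here
roc.curve(pred_svm, test$stroke)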

## Area under the curve (AUC): 0.918
# print the confusion matrix and evaluation metrics
confusionMatrix(pred_svm, 
                factor(test$stroke), 
                mode = "everything", 
                positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 1185   27
##          1  225 1373
##                                           
##                Accuracy : 0.9103          
##                  95% CI : (0.8991, 0.9206)
##     No Information Rate : 0.5018          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.8207          
##                                           
##  Mcnemar's Test P-Value : < 2.2e-16       
##                                           
##             Sensitivity : 0.9807          
##             Specificity : 0.8404          
##          Pos Pred Value : 0.8592          
##          Neg Pred Value : 0.9777          
##               Precision : 0.8592          
##                  Recall : 0.9807          
##                      F1 : 0.9159          
##              Prevalence : 0.4982          
##          Detection Rate : 0.4886          
##    Detection Prevalence : 0.5687          
##       Balanced Accuracy : 0.9106          
##                                           
##        'Positive' Class : 1               
## 
# 2. KNN model
# ROC curve
roc.curve(pred_knn, test_knn$stroke)

## Area under the curve (AUC): 0.913
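# print the confusion matrix and evaluation metrics
# note: after the factor-to-numeric conversion the stroke levels are "1" (no stroke)
# and "2" (stroke), so positive = "1" reports the metrics for the no-stroke class
confusionMatrix(pred_knn, 
                factor(test_knn$stroke), 
                mode = "everything", 
                positive = "1")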
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    1    2
##          1 1166   27
##          2  244 1373
##                                          
##                Accuracy : 0.9036         
##                  95% CI : (0.892, 0.9142)
##     No Information Rate : 0.5018         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.8072         
##                                          
##  Mcnemar's Test P-Value : < 2.2e-16      
##                                          
##             Sensitivity : 0.8270         
##             Specificity : 0.9807         
##          Pos Pred Value : 0.9774         
##          Neg Pred Value : 0.8491         
##               Precision : 0.9774         
##                  Recall : 0.8270         
##                      F1 : 0.8959         
##              Prevalence : 0.5018         
##          Detection Rate : 0.4149         
##    Detection Prevalence : 0.4246         
##       Balanced Accuracy : 0.9038         
##                                          
##        'Positive' Class : 1              
## 
# 3. Logistic Regression model
# ROC curve
roc.curve(res_log, test$stroke)

## Area under the curve (AUC): 0.773
# print the confusion matrix and evaluation metrics
confusionMatrix(factor(res_log), 
                factor(test$stroke), 
                mode = "everything", 
                positive="1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0  998  237
##          1  412 1163
##                                          
##                Accuracy : 0.769          
##                  95% CI : (0.753, 0.7845)
##     No Information Rate : 0.5018         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.5383         
##                                          
##  Mcnemar's Test P-Value : 8.486e-12      
##                                          
##             Sensitivity : 0.8307         
##             Specificity : 0.7078         
##          Pos Pred Value : 0.7384         
##          Neg Pred Value : 0.8081         
##               Precision : 0.7384         
##                  Recall : 0.8307         
##                      F1 : 0.7818         
##              Prevalence : 0.4982         
##          Detection Rate : 0.4139         
##    Detection Prevalence : 0.5605         
##       Balanced Accuracy : 0.7693         
##                                          
##        'Positive' Class : 1              
## 
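The roc.curve() calls above receive hard class labels rather than scores, so each ROC curve has a single operating point and its area reduces to the average of two rates. The sketch below works this out for the tuned SVM from its confusion matrix counts. It assumes roc.curve() is ROSE::roc.curve(response, predicted); under that assumption the calls above pass the predictions as the response, which would explain why the printed AUC differs slightly from the balanced accuracy.

```r
# Counts copied from the tuned-SVM confusion matrix above (positive class = 1)
tn <- 1185; fn <- 27; fp <- 225; tp <- 1373

sens <- tp / (tp + fn); spec <- tn / (tn + fp)
ppv  <- tp / (tp + fp); npv  <- tn / (tn + fn)

# single-point ROC area with the true labels as the response
auc_truth_as_response <- (sens + spec) / 2
# the same area when predictions and true labels are swapped
auc_pred_as_response <- (ppv + npv) / 2

round(c(truth_as_response = auc_truth_as_response,
        pred_as_response  = auc_pred_as_response), 4)
```

The second value rounds to the 0.918 printed for the SVM, while the first equals the Balanced Accuracy line of the same output. To obtain a full ROC curve, continuous scores (for example pred_log for the logistic model) would be needed in place of the thresholded labels.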

9. Conclusion